import copy
import re

import cv2
import gymnasium as gym
import numpy as np

from Environment.environment import Environment, strip_instance, non_state_factors
from Environment.Environments.Box2D.box2d_init_specs import *
from Environment.Environments.Box2D.box2d_specs import *
from Environment.Environments.Box2D.box2d_init_environment import init_box2d_environment

class Action():
    '''
    Factor holding the most recent action given to the environment.
    Continuous actions start as a zero 2-vector, discrete actions as the integer 0.
    '''
    def __init__(self, continuous):
        self.name = "Action"
        self.continuous = continuous
        self.interaction_trace = []
        if continuous:
            self.attribute = np.zeros(2)
        else:
            self.attribute = 0

    def take_action(self, action):
        # stores the action as given, without conversion
        self.attribute = action

    def get_state(self):
        # continuous: the stored vector as an array, discrete: the stored value wrapped in a length 1 array
        if self.continuous:
            return np.array(self.attribute)
        return np.array([self.attribute])

class Goal():
    '''
    Goal factor: samples a desired goal and compares it against the achieved state.

    form selects which part of the factored state is the goal:
        "Target" / "Control": first 2 state values (position) of that object
        "TargetPolyAng": target position + state values 4:6 (sin, cos of the angle)
        "TargetBallVel" / "ControlVel": first 4 state values (position + linear velocity)
        "TargetPolyVel": first 7 state values (position, linear velocity, sin, cos, angular velocity)
        "Full": position of every object that is not Action/Reward/Done/Goal
        "FullVel": 4 values for ball-like objects and 7 for polygons, for every such object
    bounds is the full (width, length) extent; it is halved so positions are sampled in [-bounds, bounds].
    '''
    # factors that never contribute dimensions to a "Full"/"FullVel" goal
    NON_GOAL_NAMES = ("Action", "Reward", "Done", "Goal")

    def __init__(self, form, target_form, all_names, bounds, goal_epsilon):
        self.name = "Goal"
        self.form = form
        self.target_form = target_form
        self.all_names = all_names
        has_target = "Target" in self.all_names
        graph_names = [n for n in self.all_names if n not in non_state_factors]
        self.target_idx = self.all_names.index("Target") if has_target else -1
        self.control_idx = self.all_names.index("Control")
        self.target_graph_idx = graph_names.index("Target") if has_target else -1
        self.control_graph_idx = graph_names.index("Control")
        # index of the object the goal refers to, -1 when the goal spans all objects
        # "ControlVel" reads the Control row in get_achieved_goal_state, so it indexes Control as well
        if self.form in ["Control", "ControlVel"]:
            self.idx, self.graph_idx = self.control_idx, self.control_graph_idx
        elif self.form in ["Full", "FullVel"]:
            self.idx, self.graph_idx = -1, -1
        else:
            self.idx, self.graph_idx = self.target_idx, self.target_graph_idx
        self.bounds = np.array(bounds).astype(float) / 2
        self.attribute = np.zeros(1) # wrong dimensions until sample_goal is called
        self.goal_epsilon = goal_epsilon
        self.interaction_trace = list()
        self.goal_range = self.generate_bounds()[0]

    def normalize_goal(self, goal):
        # scales a goal by the per-dimension bounds
        return goal / self.goal_range

    def _goal_object_names(self):
        # names of the objects that contribute to a "Full"/"FullVel" goal, in all_names order
        return [n for n in self.all_names if n not in self.NON_GOAL_NAMES]

    def _vel_width(self, name):
        # number of leading state values of name used by a "FullVel" goal:
        # 4 for ball-like objects, 7 for polygons, 0 if the object is not used
        is_target = "Target" in name
        if "Ball" in name or "Control" in name or (is_target and "Ball" in self.target_form):
            return 4
        if "Poly" in name or (is_target and "Poly" in self.target_form):
            return 7
        return 0

    def _ball_vel_bounds(self):
        # bounds and position mask for position + linear velocity
        return np.concatenate([self.bounds, self.bounds]), np.array([1,1,0,0])

    def _poly_vel_bounds(self):
        # bounds and position mask for position, linear velocity, (sin, cos) and angular velocity
        # the angular velocity bound must be a 1-d array, np.concatenate rejects scalars
        return np.concatenate([self.bounds, self.bounds, np.ones(2), np.array([np.pi * 2])]), np.array([1,1,0,0,0,0,0])

    def generate_bounds(self):
        '''
        returns (bounds, position_mask), two 1-d arrays of the goal dimension
            bounds: the symmetric limit of each goal dimension
            position_mask: 1 where the dimension is a position, 0 otherwise
        raises ValueError for an unknown goal form
        '''
        if self.form in ["Target", "Control"]:
            return self.bounds.copy(), np.ones(2)
        elif self.form == "TargetPolyAng":
            return np.concatenate([self.bounds, np.ones(2)]), np.array([1,1,0,0])
        elif self.form in ["TargetBallVel", "ControlVel"]:
            return self._ball_vel_bounds()
        elif self.form == "TargetPolyVel":
            return self._poly_vel_bounds()
        elif self.form == "Full":
            all_bounds, all_pos_masks = list(), list()
            for n in self._goal_object_names():
                all_bounds.append(self.bounds.copy())
                all_pos_masks.append(np.ones(2))
            # both outputs are flattened so they have the same layout as the other forms
            return np.concatenate(all_bounds, axis=0), np.concatenate(all_pos_masks, axis=0)
        elif self.form == "FullVel":
            all_bounds, all_pos_masks = list(), list()
            for n in self._goal_object_names():
                width = self._vel_width(n)
                if width == 0: continue
                obj_bounds, obj_mask = self._ball_vel_bounds() if width == 4 else self._poly_vel_bounds()
                all_bounds.append(obj_bounds)
                all_pos_masks.append(obj_mask)
            return np.concatenate(all_bounds, axis=0), np.concatenate(all_pos_masks, axis=0)
        raise ValueError("unknown goal form: " + str(self.form))

    def sample_angle(self):
        # (sin, cos) of a uniformly sampled angle
        ang = np.random.rand() * np.pi * 2
        return np.array([np.sin(ang), np.cos(ang)])

    def sample_ball_pos(self):
        # uniform position in [-bounds, bounds]
        return (np.random.rand(2) - 0.5) * 2 * self.bounds

    def sample_pos_ang(self):
        # position followed by (sin, cos)
        return np.concatenate([self.sample_ball_pos(), self.sample_angle()])

    def sample_vel(self):
        # 4 values, position followed by linear velocity, each uniform in [-bounds, bounds]
        return (np.random.rand(4) - 0.5) * 2 * np.concatenate([self.bounds, self.bounds])

    def sample_angular_vel(self):
        # uniform in [-2pi, 2pi)
        return np.array([np.random.rand() * 4 * np.pi - (np.pi * 2)])

    def sample_poly(self):
        # 7 values: position + linear velocity (from sample_vel), (sin, cos), angular velocity
        # sample_vel already contains the position, so it is not sampled a second time;
        # this keeps the sample the same length as generate_bounds and the [:7] achieved goal
        return np.concatenate([self.sample_vel(), self.sample_angle(), self.sample_angular_vel()])

    def sample_goal(self):
        '''
        samples a new desired goal for the current form, stores it in self.attribute and returns it
        raises ValueError for an unknown goal form
        '''
        if self.form in ["Target", "Control"]:
            samples = self.sample_ball_pos()
        elif self.form == "TargetPolyAng":
            samples = self.sample_pos_ang()
        elif self.form in ["TargetBallVel", "ControlVel"]:
            samples = self.sample_vel()
        elif self.form == "TargetPolyVel":
            samples = self.sample_poly()
        elif self.form == "Full":
            samples = np.concatenate([self.sample_ball_pos() for n in self._goal_object_names()], axis=0)
        elif self.form == "FullVel":
            samples = list()
            for n in self._goal_object_names():
                width = self._vel_width(n)
                if width == 4:
                    samples.append(self.sample_vel())
                elif width == 7:
                    samples.append(self.sample_poly())
            samples = np.concatenate(samples, axis=0)
        else:
            raise ValueError("unknown goal form: " + str(self.form))
        self.attribute = samples
        return self.attribute

    def get_achieved_goal(self, env):
        # stacks the state of every factor in all_names (zero padded to the longest) and extracts the goal part
        states = [env.object_name_dict[n].get_state() for n in self.all_names]
        longest = max([len(s) for s in states])
        state = np.stack([np.pad(s, (0, longest - s.shape[0])) for s in states], axis=0)
        return self.get_achieved_goal_state(state)

    def get_achieved_goal_state(self, object_state, fidx=None):
        '''
        extracts the achieved goal from object_state, indexed as [..., object, feature]
        leading batch dimensions are kept, the goal values are joined along the last axis
        for "Full"/"FullVel" the rows are assumed to be ordered like all_names
        '''
        # TODO: implement so that fidx is used
        if self.form == "Target":
            return object_state[...,self.target_idx,:2]
        elif self.form == "Control":
            return object_state[...,self.control_idx,:2]
        elif self.form == "TargetPolyAng":
            return np.concatenate([object_state[...,self.target_idx,:2], object_state[...,self.target_idx,4:6]], axis=-1)
        elif self.form == "TargetBallVel":
            return object_state[...,self.target_idx,:4]
        elif self.form == "ControlVel":
            return object_state[...,self.control_idx,:4]
        elif self.form == "TargetPolyVel":
            return object_state[...,self.target_idx,:7]
        elif self.form in ["Full", "FullVel"]:
            num_rows = object_state.shape[-2]
            parts = list()
            for i, n in enumerate(self.all_names):
                # Action, Reward, Done and the Goal row itself are skipped so the layout matches sample_goal
                if i >= num_rows or n in self.NON_GOAL_NAMES: continue
                width = 2 if self.form == "Full" else self._vel_width(n)
                if width > 0: parts.append(object_state[...,i,:width])
            return np.concatenate(parts, axis=-1)

    def add_interaction(self, reached_goal):
        # records which factors the goal interacted with when it is reached
        # NOTE(review): "FullVel" and the Control forms add nothing here, confirm this is intended
        if reached_goal:
            if self.form in ["Target", "TargetPolyAng", "TargetBallVel", "TargetPolyVel"]:
                self.interaction_trace += ["Target"]
            elif self.form == "Full":
                self.interaction_trace += [n for n in self.all_names if n not in non_state_factors]

    def get_state(self):
        # the state of the goal factor is the desired goal
        return self.attribute

    def check_goal(self, env):
        # True if the (unnormalized) achieved goal is within goal_epsilon of the desired goal in L2 distance
        return np.linalg.norm(self.get_achieved_goal(env) - self.attribute, axis=-1) < self.goal_epsilon


class Box2DObjWrapper():
    '''
    Wraps a Box2D body (or None if the object is not present) and exposes its state as a flat array.
    Ball-like objects (Control, Target, Ball): position, linear velocity, radius (5 values)
    Polygons (Poly): position, linear velocity, sin(angle), cos(angle), angular velocity, radius (8 values)
    A missing object returns zeros of the matching length.
    '''
    def __init__(self, name, box2dobj):
        self.name = name
        self.obj = box2dobj
        self.interaction_trace = list()

    def get_state(self):
        ball_like = any(self.name.find(tag) != -1 for tag in ("Control", "Target", "Ball"))
        if ball_like:
            if self.obj is None:
                return np.zeros(5)
            position = np.array(self.obj.position).tolist()
            radius = float(self.obj.fixtures[0].shape.radius) # there should only be one fixture
            velocity = np.array(self.obj.linearVelocity).tolist()
            return np.array([*position, *velocity, radius])
        if self.name.find("Poly") != -1:
            if self.obj is None:
                return np.zeros(8)
            position = np.array(self.obj.position).tolist()
            # the radius is looked up by the instance-stripped name rather than read from the body
            radius = preassigned_radius[preassigned_nid[strip_instance(self.name)]][1] # TODO: assumes there are only in single digits instances of objects, will error if double digit
            velocity = np.array(self.obj.linearVelocity).tolist()
            angle = float(self.obj.angle)
            heading = [np.sin(angle), np.cos(angle)]
            angular_velocity = float(self.obj.angularVelocity)
            return np.array([*position, *velocity, *heading, angular_velocity, radius])
        # any other name has no state here and yields None; should error if a different name is found
        return None


class Box2DEnvironment(Environment):
    def __init__(self, frameskip = 1, horizon=200, variant="", fixed_limits=False, renderable=False, render_masks=False):
        ''' required attributes:
            num actions: int or None
            action_space: gym.Spaces
            action_shape: tuple of ints
            observation_space = gym.Spaces
            done: boolean
            reward: int
            seed_counter: int
            discrete_actions: boolean
            name: string
        All the below properties are set by the subclass

        @param frameskip number of simulator steps taken per call to step
        @param horizon number of steps before an episode times out, values <= 0 disable the timeout
        @param variant key into box2d_variants, which provides the physical and goal parameters
        @param fixed_limits stored on the environment and passed to the base class
        @param renderable not used in this constructor
        @param render_masks passed to the box2d environment
        '''
        super().__init__(frameskip=frameskip, variant=variant, fixed_limits=fixed_limits)
        # environment properties, the order must match the tuple stored in box2d_variants
        (self.continuous_actions, self.num_balls, self.poly_ids, self.num_poly_inst,
            self.dyn_damping, self.density, self.init_velocity, self.general_radius_min, self.general_radius_max,
            self.use_target, self.target_form, self.target_damping, self.target_radius,
            self.target_density, self.target_init_vel, self.control_damping, self.control_radius, self.control_density,
            self.grav_domain, self.phyre_path, self.length, self.width, self.force_scaling, self.render_size,
            self.min_valid, self.max_valid, self.force_live,
            self.max_steps, self.goal_epsilon, self.goal_form) = box2d_variants[variant]

        self.num_actions = -1 if self.continuous_actions else 9 # discrete actions are a length 8 clock applying self.force_scaling in each direction
        self.name = "box2d"
        self.fixed_limits = fixed_limits # should be fixed to self.length, self.width
        self.discrete_actions = not self.continuous_actions
        self.transpose = False # should already be transposed TODO: I think
        self.render_masks = render_masks
        self.goal_idx = -1
        self.goal_graph_idx = -1
        if self.goal_epsilon > 0:
            self.goal_based = True
        elif not hasattr(self, "goal_based"):
            # goal_based is read below, give it a value if the base class did not define one
            self.goal_based = False

        # spaces
        self.action_shape = (2,) if self.continuous_actions else (1,) # applies forces in x,y plane
        self.action_space = gym.spaces.Box(low=np.array([-1,-1]), high=np.array([1,1])) if self.continuous_actions else gym.spaces.Discrete(self.num_actions)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.render_size, self.render_size), dtype=np.uint8) # renders in squares
        self.pos_size = 2 # 2D domain uses the first two values for position

        # state components
        self.action = Action(self.continuous_actions)
        self.extracted_state = None

        # running values
        self.itr = 0
        self.total_itr = 0
        self.max_steps = horizon if horizon > 0 else 1e12 # replaces the max_steps read from the variant

        # factorized state properties
        self.all_names, self.object_names, self.all_select_ids = self.generate_names() # generates based on initialization
        self.num_objects = len(self.all_names)
        self.object_sizes, self.object_range, self.object_dynamics, self.object_range_true, self.object_dynamics_true, self.position_masks = generate_object_dicts(self.continuous_actions, self.all_names, self.object_names, self.length, self.width) # must be initialized, a dictionary of name to length of the state
        self.object_instanced = self.generate_instancing() # TODO: support instancing
        self.object_proximal = {n: True for n in self.object_names if n not in ["Action", "Reward", "Done"]} # all objects support proximity, to varying degrees

        self.box2d_environment = init_box2d_environment(self.num_balls, self.poly_ids, self.num_poly_inst, 
                                                        self.dyn_damping, self.density, self.init_velocity, self.general_radius_min, self.general_radius_max,
                                                        self.use_target, self.target_form, self.target_damping, self.target_radius,
                                                        self.target_density, self.target_init_vel, self.control_damping, self.control_radius, self.control_density,
                                                        self.grav_domain, self.phyre_path, self.length, self.width, self.object_names, self.force_scaling, self.render_size, 
                                                        self.num_objects, self.object_instanced, render_masks=self.render_masks)
        self.reset()
        if self.goal_based: self.goal_idx = self.goal.idx
        # proximity state components
    
    def generate_instancing(self):
        '''
        returns a dictionary of object class name to the maximum number of instances of that class
        empty for phyre based environments, which are not implemented yet
        '''
        object_instanced = dict()
        if len(self.phyre_path) > 0:
            # TODO: write this code
            return object_instanced
        for name in self.object_names:
            if name == "Ball":
                count = self.num_balls
            elif name.find("Poly") != -1:
                count = self.num_poly_inst
            else: # only one of Control, Target, Action, Done, Reward
                count = 1
            object_instanced[name] = count
        return object_instanced
    
    def reset(self, goal=None, **kwargs):
        '''
        Starts a new episode and returns the new state.
        Most of the heavy lifting happens in the sub-function, assigns the valid
        and the wrapper to return the state component

        @param goal if given and the environment has a goal, replaces the freshly sampled goal
        @param kwargs accepted for interface compatibility and ignored
        '''
        self.itr = 0
        if len(self.phyre_path) > 0:
            pass # TODO: call the reset function for phyre based environment
        else:
            # generate the valid objects
            self.action = Action(self.continuous_actions)
            # Action, reward and done are always valid (but not control)
            self.instance_length = np.random.randint(max(0, self.min_valid), self.max_valid + 1)
            use_ids = list(self.all_select_ids) # only the list itself is modified, a shallow copy is enough
            forced_ids = list()
            if len(self.force_live) > 0:
                forced_slot = self.all_names.index(self.force_live) - 1 # -1 is for Action, which is not in all_select_ids
                forced_ids = [self.all_select_ids[forced_slot]]
                use_ids.pop(forced_slot)
                # the forced object takes one of the slots, a draw of zero must not become negative
                self.instance_length = max(0, self.instance_length - 1)
            if len(use_ids) == 0:
                name_ids = forced_ids
            else:
                # sampling without replacement cannot draw more ids than are available
                num_random = min(self.instance_length, len(use_ids))
                name_ids = np.random.choice(use_ids, size=num_random, replace=False).tolist() + forced_ids
            self.box2d_environment.reset(type_instance_dict={"Control": [i for i in name_ids if i == CONTROL_ID], "Ball": [i for i in name_ids if i == BALL_ID], "Poly": [i-2 for i in name_ids if i in POLY_IDS], "Target": [i for i in name_ids if i == TARGET_ID]},
                                         max_count_dict={"Ball": self.num_balls, "Poly": self.num_poly_inst})
            self.generate_from_env() # rebuilds the object wrappers and samples a new goal
            # generate_from_env sets self.goal to None when the environment has no goal
            if goal is not None and self.goal is not None: self.goal.attribute = goal
        return self.get_state()

    def generate_from_env(self):
        '''
        Rebuilds valid_names, the name to object dictionary and the object list from the
        current box2d environment, then creates the goal (first call) or resamples it (later calls).
        Sets self.goal to None if goal_epsilon is negative.
        '''
        sim = self.box2d_environment
        control_part = ["Control"] if sim.control_ball_attrs is not None else []
        target_part = ["Target"] if sim.target_attrs is not None else []
        self.valid_names = ["Action"] + control_part + sim.ball_names + sim.poly_names + target_part + ["Reward", "Done"]

        # every name gets a wrapper, objects missing from the simulator wrap None
        wrappers = dict()
        for n in self.all_names:
            wrappers[n] = Box2DObjWrapper(n, sim.object_dict[n] if n in sim.object_dict else None)
        self.object_name_dict = {**wrappers, "Action": self.action, "Reward": self.reward, "Done": self.done}
        self.objects = [self.object_name_dict[n] for n in self.all_names]

        if self.goal_epsilon < 0:
            self.goal = None
            return

        needs_new_goal = not hasattr(self, "goal") or self.object_name_dict["Goal"] is None or self.goal is None
        if needs_new_goal:
            # add in the goal after other components have been added TODO: it isn't really necessary to do it here
            self.goal = Goal(self.goal_form, self.target_form, self.all_names, np.array([self.width-self.target_radius, self.length-self.target_radius]), self.goal_epsilon)
            self.object_name_dict["Goal"] = self.goal
            goal_dim = len(self.goal.sample_goal())
            goal_bounds, goal_mask = self.goal.generate_bounds()
            self.object_sizes["Goal"] = goal_dim
            # separate arrays per entry so that no two dictionaries share storage
            self.object_range["Goal"] = [-goal_bounds, goal_bounds.copy()]
            self.object_dynamics["Goal"] = [-np.ones(goal_dim) * 0.01, np.ones(goal_dim) * 0.01]
            self.object_range_true["Goal"] = [-goal_bounds, goal_bounds.copy()]
            self.object_dynamics_true["Goal"] = [-np.ones(goal_dim) * 0.01, np.ones(goal_dim) * 0.01]
            self.position_masks["Goal"] = goal_mask
            self.goal_space = gym.spaces.Box(low=-goal_bounds, high=goal_bounds)
            self.goal_idx = self.goal.idx
            self.goal_graph_idx = self.goal.graph_idx
        else:
            # the wrapper created above for "Goal" has no state, so the existing goal replaces it
            self.object_name_dict["Goal"] = self.goal
            self.goal.sample_goal()
        # rebuilt so that the list holds the goal object rather than its wrapper
        self.objects = [self.object_name_dict[n] for n in self.all_names]

    def generate_names(self):
        '''
        returns (all_names, object_names, all_select_ids)
            all_names: one name per object instance, instances are numbered when a class has more than one
            object_names: one name per object class
            all_select_ids: one type id per selectable instance (Control, balls, polygons, Target)
        "Goal" is inserted before Reward and Done when goal_epsilon is non-negative
        '''
        if len(self.phyre_path) > 0:
            return load_names_from_phyre(self.phyre_path)

        target_part = ["Target"] if self.use_target else list()

        # TODO: could add boundaries, fixed objects
        object_names = ["Action", "Control"]
        if self.num_balls > 0:
            object_names.append("Ball")
        object_names += [preassigned_names[i] for i in self.poly_ids]
        object_names += target_part + ["Reward", "Done"]

        all_names = ["Action", "Control"]
        all_select_ids = [CONTROL_ID]
        for ball in range(self.num_balls):
            all_names.append("Ball" + str(ball) if self.num_balls > 1 else "Ball")
            all_select_ids.append(BALL_ID)
        for poly_id in self.poly_ids:
            for inst in range(self.num_poly_inst):
                all_names.append(preassigned_names[poly_id] + str(inst) if self.num_poly_inst > 1 else preassigned_names[poly_id])
                all_select_ids.append(poly_id + 2)
        if self.use_target:
            all_select_ids.append(TARGET_ID)
        all_names += target_part + ["Reward", "Done"]

        if self.goal_epsilon >= 0:
            object_names.insert(-2, "Goal")
            all_names.insert(-2, "Goal")
        return all_names, object_names, all_select_ids

    def convert_action(self, action):
        '''
        continuous actions are returned unchanged
        discrete action 0 is the no-op (zero vector), actions 1-8 are unit vectors spaced an eighth of a turn apart
        '''
        if self.continuous_actions:
            return action
        if action != 0:
            angle = (action-1) / 8 * 2 * np.pi
            return np.array([np.cos(angle), np.sin(angle)])
        return np.zeros(2) # zero action is no-op

    def reset_traces(self):
        # gives every object a fresh, empty interaction trace
        for factor in self.objects:
            factor.interaction_trace = []

    def step(self, action, render= False, no_timeout=False):
        '''
        Applies action for frameskip simulator steps, or until the episode ends.
        @param action the action passed to the box2d environment
        @param render if True, the returned state contains a rendered frame
        @param no_timeout if True, the episode is not ended by reaching max_steps
        returns
            state as dict: next raw_state (image or observation) next factor_state (dictionary of name of object to tuple of object bounding box and object property)
            reward: the true reward from the environment
            done flag: if an episode ends, done is True
            info: a dict with additional info
        When the episode ends the environment is reset inside this call, so the returned state
        belongs to the new episode, while the factor graph in info is taken before the reset.
        '''
        self.action.attribute = action
        self.reset_traces()
        factor_graph = None
        self.reward.attribute = 0.00
        self.done.attribute = False
        # counted before the timeout check so that every episode lasts exactly max_steps steps,
        # reset sets itr back to 0 and nothing is added after an in-step reset
        self.itr += 1
        rew, done = self.reward.attribute, self.done.attribute # defined even if frameskip is 0
        for _ in range(self.frameskip):
            self.box2d_environment.step(action)
            _, contact_names = self.box2d_environment.get_contacts()
            # TODO: In theory, the traces can contain duplicates without issue 
            for n in contact_names:
                self.object_name_dict[n].interaction_trace += contact_names[n]
            self.object_name_dict["Control"].interaction_trace += ["Action"]
            # TODO: if reward or done is dependent, add traces here

            # the reward is 1 when the goal is reached, if a goal is used
            if "Goal" in self.object_name_dict:
                self.reward.attribute = float(self.goal.check_goal(self))
                if self.reward.attribute == 1: 
                    # reaching the goal does not end the episode
                    self.goal.add_interaction(True)
            timed_out = (not no_timeout) and self.itr >= self.max_steps
            self.done.attribute = timed_out or self.done.attribute
            rew, done = self.reward.attribute, self.done.attribute

            if self.done.attribute:
                factor_graph = self.get_factor_graph()
                self.reset()
                # restore the values of the finished episode for the returned state
                self.done.attribute, self.reward.attribute = True, rew
                break
        return self.get_state(render= render), rew, done, self.get_info(factor_graph)

    def set_goal_params(self, goal_params):
        # sets the goal radius from goal_params["radius"] on both the goal and the environment
        # does nothing when the environment has no goal
        if self.goal is None:
            return
        self.goal.goal_epsilon = goal_params["radius"]
        self.goal_epsilon = goal_params["radius"]

    def draw_goal(self, frame):
        '''
        Draws the goal as a filled circle of radius goal_epsilon on top of frame.
        @param frame an image array, or a dict of images with the full image under "obs"
        returns the frame with the goal added; a dict is updated in place and also gets the
        goal-only image under "Goal". The frame is returned unchanged if there is no goal.
        '''
        # TODO: add goals that are not a single position
        # TODO: add individual mask to goal rendering
        if self.goal is None:
            return frame
        ppm = self.box2d_environment.ppm
        pos = self.goal.attribute[:2]
        center = np.array(pos) + np.array((self.width / 2, self.length/2))
        center = np.array((center[1], center[0])) * ppm
        # a tuple of python ints is accepted by every OpenCV version, int() truncates like astype(int)
        center_px = tuple(int(c) for c in center)
        radius = int(self.goal.goal_epsilon * ppm)

        is_dict = isinstance(frame, dict)
        image = frame["obs"] if is_dict else frame
        addframe = np.zeros(image.shape).astype(np.uint8)
        cv2.circle(addframe, center_px, radius, (0,0,127), -1)
        # NOTE(review): if the frame is uint8 as well, the additions below wrap around where the sum exceeds 255
        if is_dict:
            frame["obs"] += addframe
            frame["Goal"] = addframe
            return frame
        return frame + addframe

    def render(self, mode='human'):
        # renders the box2d environment and draws the goal on top, mode is not used
        return self.draw_goal(self.box2d_environment.render())

    def get_state(self, render= False):
        '''
        gets the image (raw state) and extracted state
        raw_state is a 2x2 array of zeros if not rendering
        factored_state maps every name to its state rounded to 5 decimals (Done is not rounded),
        plus the "VALID_NAMES" and "TRACE" entries
        '''
        if render:
            raw_state = self.render()
        else:
            raw_state = np.zeros((2,2))
        factored_state = dict()
        for name in self.all_names:
            obj_state = self.object_name_dict[name].get_state()
            factored_state[name] = obj_state if name == "Done" else np.round(obj_state, decimals=5)
        factored_state["VALID_NAMES"] = self.valid_binary(self.valid_names)
        factored_state["TRACE"] = self.current_trace()
        return {"raw_state": raw_state, "factored_state": factored_state}

    def get_info(self, factor_graph=None):
        '''
        returns the info dict: the factor graph, the achieved and desired goals and the success flag
        the goal entries are None/None/False when there is no goal object
        factor_graph is computed from the current state unless one is passed in
        '''
        info = {"achieved_goal": None, "desired_goal": None, "success": False}
        if "Goal" in self.object_name_dict:
            info["achieved_goal"] = self.goal.get_achieved_goal(self)
            info["desired_goal"] = self.goal.get_state()
            info["success"] = self.goal.check_goal(self)
        if factor_graph is None:
            factor_graph = self.get_factor_graph()
        return {"factor_graph": factor_graph, **info}

    def get_achieved_goal_state(self, state, fidx=None):
        # forwards to the goal object, which extracts the achieved goal from an object state array
        goal = self.goal
        return goal.get_achieved_goal_state(state, fidx=fidx)

    def get_itr(self):
        # the step counter of the current episode
        current_itr = self.itr
        return current_itr

    def set_from_factored_state(self, factored_state, valid_names):
        '''
        Sets the environment to the given factored state.
        Only the entries named in valid_names are passed to the box2d environment,
        factored_state itself is not modified.
        '''
        self.valid_names = valid_names
        self.instance_length = len(valid_names) - 3
        valid_filtered_factored = copy.deepcopy(factored_state)
        dropped = [name for name in factored_state.keys() if name not in self.valid_names]
        for name in dropped:
            del valid_filtered_factored[name]
        self.box2d_environment.reset(object_state_dict=valid_filtered_factored)
        self.generate_from_env()

    def demonstrate(self):
        '''
        placeholder for human demonstration: should show an image and read a keystroke action
        currently always returns action 0
        TODO: implement, box2D has opencv support for rendering
        '''
        return 0
